BayesianSurvival.jl

Reproducing example models from survivalstan

Author

Nikolas Siccha

Survival analysis - what’s that?

According to Wikipedia:

Survival analysis is a branch of statistics for analyzing the expected duration of time until one event occurs, such as death in biological organisms and failure in mechanical systems.

We’ll consider the setting used for the examples at https://jburos.github.io/survivalstan/Examples.html. We will have a model which, for a set of persons, takes

  • a set of covariates (age and gender per person),
  • a list of times at which the event either occurs, or until which the event did not occur (one time and event/survival indicator per person),

and, after following standard Bayesian procedures via conditioning on observations, yields a way to predict the survival time of unobserved persons, given the same covariates.

For fixed covariates \(x\) and model parameters \(\theta\), the models below will give us a way to compute a (piecewise exponential) survival function \(S(t) = Pr(T > t)\), i.e. a function which models the probability that the event in question has not occurred until the specified time \(t\). Both in general and in our setting, the survival function is the solution to a simple linear first-order differential equation with variable coefficients; concretely, we have

\[ S'(t) = -\lambda(t)S(t)\quad\text{and}\quad S(0) = 1 \] where the hazard function/rate \(\lambda(t)\) is a non-negative function, such that \(S(t)\) is monotonically non-increasing and has values in \((0, 1]\). The log of the survival function is then \[ \log S(t) = -\int_0^t\lambda(\tau) d\tau. \]

As \(S(t)\) models survival (the non-occurrence of an event), the log likelihood of the occurrence of an event at a given time \(t\) is \[ \log p_1(t) = \log(-S'(t)) = \log \lambda(t) + \log S(t) = \log \lambda(t) -\int_0^t\lambda(\tau) d\tau \] and the log likelihood of survival up to at least time \(t\) is \[ \log p_0(t) = \log S(t) = -\int_0^t\lambda(\tau) d\tau. \]

The first term (\(p_1(t)\)) has to be used for the likelihood contribution of observations where the event occurs (survival up to exactly time \(t\)), while the second term (\(p_0(t)\)) has to be used for the likelihood contribution of observations where the event does not occur before the end of the observation time, a.k.a. censored observations.

If the hazard function \(\lambda(\tau)\) is constant and if we do not care about constant terms (as e.g. during MCMC) we can use the Poisson distribution to compute the appropriate terms “automatically”. For piecewise constant hazard functions, it’s possible to chain individual Poisson likelihoods to compute the overall likelihood (modulo a constant term).
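As a quick numerical check of this Poisson trick (a Python sketch, not part of the notebook's Julia code): for a single interval of width \(\Delta t\) with constant hazard \(\lambda\), the Poisson log-pmf with mean \(\Delta t\,\lambda\) reproduces both likelihood terms up to the constant \(\log \Delta t\).

```python
import math

def poisson_logpmf(k, mu):
    # log of the Poisson pmf: k*log(mu) - mu - log(k!)
    return k * math.log(mu) - mu - math.lgamma(k + 1)

lam, dt = 0.3, 2.5  # constant hazard and interval width

# Censored observation (no event in the interval): log p0 = -dt*lam,
# which is exactly the Poisson log-pmf at k = 0.
assert math.isclose(poisson_logpmf(0, dt * lam), -dt * lam)

# Event in the interval: log p1 = log(lam) - dt*lam, while the Poisson
# log-pmf at k = 1 is log(dt*lam) - dt*lam -- the same up to the
# data-dependent constant log(dt), which is irrelevant during MCMC.
assert math.isclose(poisson_logpmf(1, dt * lam),
                    math.log(lam) - dt * lam + math.log(dt))
```

Note that the interval *width* \(\Delta t\) enters the Poisson mean, which is why assigning log(t_obs) instead of the log timeslab width matters in the original models discussed below.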

For piecewise constant hazard functions of the form \[ \lambda(t) = \begin{cases} \lambda_1 & \text{if } t \in [t_0, t_1],\\ \lambda_2 & \text{if } t \in (t_1, t_2],\\ \dots \end{cases} \] with \(0 = t_0 < t_1 < t_2 < \dots\) the survival function can be directly computed as \[ \log S(t_j) = -\sum_{i=1}^j (t_i-t_{i-1}) \lambda_i. \]
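A direct translation of this formula, sketched in Python (the notebook itself does this via its Julia helpers):

```python
import math

def log_survival(t_grid, lam):
    """log S(t_j) = -sum_{i<=j} (t_i - t_{i-1}) * lam_i for a piecewise
    constant hazard; t_grid = [t_0 = 0, t_1, ..., t_J], len(lam) = J."""
    log_s, out = 0.0, []
    for i in range(1, len(t_grid)):
        log_s -= (t_grid[i] - t_grid[i - 1]) * lam[i - 1]
        out.append(log_s)
    return out

# With a single piece the formula collapses to the familiar
# exponential survival function, log S(t) = -lam * t:
assert math.isclose(log_survival([0.0, 3.0], [0.2])[0], -0.6)

# log S is non-increasing, as required for a survival function:
log_s = log_survival([0.0, 1.0, 2.5, 4.0], [0.2, 0.5, 0.1])
assert all(a >= b for a, b in zip(log_s, log_s[1:]))
```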

Why do this? Why reimplement things?

Out of curiosity, to figure out whether I could, and to demonstrate that I understand survival analysis. Writing down the math is nice and all, but to get correct simulation results, every little detail has to be right. At least in principle; in practice the simulation can still be subtly wrong due to errors which don’t crash everything, but only e.g. introduce biases.

Simulation

Simulated data

To simulate the data, we generate (for 100 persons)

  • the age from a Poisson distribution with mean 55,
  • the gender (male or not) from a Bernoulli distribution with mean 1/2,
  • assume a hazard rate that is constant in time, computed as log(hazard) = -3 + .5 * male (note that age does not actually enter the rate),
  • draw true survival times true_t from an Exponential distribution with rate parameter hazard,
  • cap them at a censor_time of 20, i.e. t = min(true_t, censor_time), and
  • set survived to true if true_t > censor_time and false otherwise.
  • Dataframe
  • Code
100×7 DataFrame
75 rows omitted
Row age male rate true_t t survived idx
Int64 Bool Float64 Float64 Float64 Bool Int64
1 51 false 0.0497871 20.672 20.0 true 1
2 51 false 0.0497871 24.6577 20.0 true 2
3 50 false 0.0497871 31.853 20.0 true 3
4 55 false 0.0497871 9.7404 9.7404 false 4
5 46 false 0.0497871 15.3396 15.3396 false 5
6 45 false 0.0497871 41.6487 20.0 true 6
7 56 false 0.0497871 12.7356 12.7356 false 7
8 47 true 0.082085 0.761362 0.761362 false 8
9 61 false 0.0497871 6.27149 6.27149 false 9
10 57 false 0.0497871 16.5202 16.5202 false 10
11 54 true 0.082085 17.3104 17.3104 false 11
12 49 false 0.0497871 0.787504 0.787504 false 12
13 56 true 0.082085 2.42708 2.42708 false 13
⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
89 59 false 0.0497871 21.8631 20.0 true 89
90 50 true 0.082085 3.82811 3.82811 false 90
91 55 true 0.082085 12.3331 12.3331 false 91
92 56 true 0.082085 29.3846 20.0 true 92
93 47 true 0.082085 18.4165 18.4165 false 93
94 51 false 0.0497871 25.4902 20.0 true 94
95 65 false 0.0497871 21.1051 20.0 true 95
96 67 true 0.082085 19.4895 19.4895 false 96
97 62 false 0.0497871 8.24189 8.24189 false 97
98 65 false 0.0497871 31.491 20.0 true 98
99 53 true 0.082085 6.39705 6.39705 false 99
100 67 true 0.082085 6.19111 6.19111 false 100

Currently, the formula (rate_form) is hardcoded to match the examples.

sim_data_exp_correlated(rng=Random.default_rng(); N, censor_time, rate_form, rate_coefs) = begin 
    idx = 1:N
    age = rand(rng, Poisson(55), N)
    male = rand(rng, Bernoulli(.5), N)
    rate = @. exp(rate_coefs[1] + male * rate_coefs[2])
    true_t = rand.(rng, ConstantExponentialModel.(rate))
    t = min.(true_t, censor_time)
    survived = true_t .> censor_time
    DataFrame((;age, male, rate, true_t, t, survived, idx))
end

For all simulations,

  • we model the hazard function \(\lambda_i(t)\) of person \(i = 1,\dots,100\) to be piecewise constant, with as many pieces as there are unique event times, plus a final one which goes from the largest observed event time to the censor time,
  • every person’s hazard function is unique (provided the covariates are unique),
  • the personwise (\(i\)) and timeslabwise (\(j\)) hazard values will be of the form \[ \log\lambda_{i,j} = \log a + \log\kappa_j + \langle{}X_i,\beta_j\rangle{}, \] where \(\log a\) is a scalar intercept, \(\log\kappa_j\) is a time-varying (but person-constant) effect, \(X_i\) are the \(i\)-th person’s covariates, and \(\beta_j\) are the potentially time-varying covariate effects (in timeslab \(j\)). For the first two models, \(\beta\) will be constant, while it will vary for the last model.
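In matrix form, the log-hazard values above can be assembled as follows (a Python sketch with made-up numbers; the names are hypothetical and not taken from the Julia code):

```python
import math

def log_hazard_matrix(log_a, log_kappa, X, beta):
    """log lambda[i][j] = log_a + log_kappa[j] + <X[i], beta[j]>,
    with i indexing persons and j indexing timeslabs. For the first
    two models every row of beta is the same (constant covariate
    effects); for the last model the rows vary over timeslabs."""
    return [[log_a + log_kappa[j]
             + sum(x * b for x, b in zip(X[i], beta[j]))
             for j in range(len(log_kappa))]
            for i in range(len(X))]

X = [[55.0, 0.0], [60.0, 1.0]]   # hypothetical (age, male) covariates
log_kappa = [0.1, -0.2]          # time-varying, person-constant effect
beta = [[0.0, 0.5], [0.0, 0.5]]  # covariate effects per timeslab
lh = log_hazard_matrix(-3.0, log_kappa, X, beta)
assert math.isclose(lh[1][0], -3.0 + 0.1 + 0.5)
```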

pem_survival_model

  • Discussion
  • Posterior parameter and predictive plots
  • Reimplemented model
  • Original model

The easiest model. The covariate effects are constant (\(\beta_1=\beta_2=\dots\)) and the time-varying (but person-constant) effect \(\log\kappa_j\) has a hierarchical normal prior with mean 0 and unknown scale (with a standard half-normal prior). There seems to be a small mistake in the original model, where at line 42 (AFAICT) log_t_dur = log(t_obs) assigns the logarithm of the event time to the variable which should contain the logarithm of the timeslab width.
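For concreteness, the suspected fix replaces log(t_obs) with the log of the first differences of t_obs. A Python sketch with hypothetical timeslab boundaries:

```python
import math

t_obs = [1.0, 2.5, 4.0]  # hypothetical end times of the timeslabs

# What the original model computes:
log_t_obs = [math.log(t) for t in t_obs]

# What the variable name log_t_dur suggests it should contain,
# i.e. the log of each timeslab's width (first differences of t_obs):
t_dur = [t_obs[0]] + [b - a for a, b in zip(t_obs, t_obs[1:])]
log_t_dur = [math.log(d) for d in t_dur]

assert t_dur == [1.0, 1.5, 1.5]
assert log_t_obs != log_t_dur  # the two only agree on the first slab
```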

function pem_survival_model(;
    survived,
    t,
    design_matrix,
    likelihood=true
)
    (;
        n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
    ) = prepare_survival(;t, design_matrix)
    StanBlocks.@stan begin 
        @parameters begin 
            log_hazard_intercept::real
            beta::vector[n_covariates]
            log_hazard_timewise_scale::real(lower=0)
            log_hazard_timewise::vector[n_timepoints]
        end
        log_hazard_personwise = design_matrix*beta
        StanBlocks.@model @views begin 
            log_hazard_intercept ~ normal(0, 1)
            beta ~ cauchy(0, 2)
            log_hazard_timewise_scale ~ normal(0, 1)
            log_hazard_timewise ~ normal(0, log_hazard_timewise_scale)
            log_lik = Base.broadcast(1:n_persons) do person 
                idxs = 1:end_idxs[person]
                survival_lpdf(
                    survived[person], 
                    StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[idxs]),
                    log_dts[idxs]
                )
            end
            likelihood && (target += sum(log_lik))
        end
        StanBlocks.@generated_quantities begin
            log_lik = collect(log_lik)
            t_pred = map(1:n_persons) do person 
                for timepoint in 1:n_timepoints
                    log_hazard = log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[timepoint]
                    rv = rand(Exponential(exp(-log_hazard)))
                    rv <= dt[timepoint] && return t0[timepoint] + rv
                end
                t1[end]
            end
        end
    end
end
/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // main data matrix (per observed timepoint*record)
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, X]

 // timepoint-specific data (per timepoint, ordered by timepoint id)
 t_obs      = observed time since origin for each timepoint id (end of period)
 t_dur      = duration of each timepoint period (first diff of t_obs)

*/
// Jacqueline Buros Novik <jackinovik@gmail.com>

data {
  // dimensions
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;

  // data matrix
  int<lower=1, upper=N> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars

  // timepoint data
  vector<lower=0>[T] t_obs;
  vector<lower=0>[T] t_dur;
}
transformed data {
  vector[T] log_t_dur;  // log-duration for each timepoint
  int n_trans[S, T];

  log_t_dur = log(t_obs);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }
}
parameters {
  vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
  vector[M] beta;         // beta for each covariate
  real<lower=0> baseline_sigma;
  real log_baseline_mu;
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;     // unstructured baseline hazard for each timepoint t

  log_baseline = log_baseline_mu + log_baseline_raw + log_t_dur;

  for (n in 1:N) {
    log_hazard[n] = log_baseline[t[n]] + x[n,]*beta;
  }
}
model {
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);
  log_baseline_raw ~ normal(0, baseline_sigma);
}
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_mu + log_baseline_raw);

  // prepare log_lik for loo-psis
  for (n in 1:N) {
      log_lik[n] = poisson_log_log(event[n], log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
        if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_haz;

              // determine predicted value of this sample's hazard
              n = n_trans[samp, tp];
              log_haz = log_baseline[tp] + x[n,] * beta;

              // now, make posterior prediction of an event at this tp
              if (log_haz < log(pow(2, 30)))
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9;

              // summarize survival time (observed) for this pt
              if (pred_y >= 1) {
                  // mark this patient as ineligible for future tps
                  // note: deliberately treat 9s as events
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

          }
      } // end per-timepoint loop

      // if patient still alive at max
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

pem_survival_model_randomwalk

  • Discussion
  • Posterior parameter and predictive plots
  • Reimplemented model
  • Original model

Identical to the first model, except that the time-varying (but person-constant) effect \(\log\kappa_j\) should have a “random walk” prior. AFAICT, the original model has the same small mistake as the first one (this time at line 43), but IMO some other (minor) things go “wrong” in constructing the “random walk” prior. Or rather, I believe that instead of the random walk prior as implemented in the original code, an approximate Brownian motion / Wiener process prior would have been a better choice:

A random walk prior as implemented in the original code will imply different priors for different numbers of persons and also for different realizations of the event times, while an approximate Wiener process prior does not (or rather, much less). Consider the following:

(Gaussian) random walk prior

For random walk parameters \(x_1, x_2, \dots\) with scale parameter \(\sigma\), the (conditional) prior density is \[ p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, \sigma^2) \text{ for } i=1,2,\dots \] and with \(x_0\) another parameter with appropriate prior.

Approximate (Gaussian) Wiener process prior

Following Wikipedia:

The Wiener process \(W_t\) is characterised by the following properties: […] W has Gaussian increments: […] \(W_{t+u} - W_t \sim \mathcal{N}(0,u)\).

I.e., for timepoints \(0 = t_0 < t_1 < t_2 < \dots\) as above, the (conditional) prior density of the (shifted) Wiener process values \(x_1, x_2, \dots\) with scale parameter \(\sigma\) is \[ p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, (t_i-t_{i-1})\sigma^2) \text{ for } i=1,2,\dots \] and with \(x_0\) as before.

Dependence on the observed event times

The difference between the two priors becomes most easily apparent by looking at the implied prior on the (log) hazard at (or right before) the censor time \(t_\text{censor} = t_{N+1}\), for varying numbers of unique observed event times \(N\). For the random walk prior, we’ll have \[ x_j \sim \mathcal{N}(x_0, j\sigma^2) \text{ for } j = 1,\dots,N+1, \] while for the Wiener process prior, we’ll have \[ x_j \sim \mathcal{N}(x_0, t_j\sigma^2) \text{ for } j = 1,\dots,N+1. \] In particular, for \(j=N+1\) (i.e. at censor time), we get a constant prior distribution for the Wiener process prior, but for the random walk prior we get a prior distribution that depends on the number of unique observed event times \(N\). Similarly, even for fixed \(N\), the implied prior for “interior” time slabs depends (potentially strongly) on the realization of the event times for the random walk prior, while there’s “no” such dependence for the Wiener process prior. Caveat: there will actually be a dependence of the implied prior on the event time realizations for the Wiener process as well, but this is only due to the piecewise-constant “assumption” and can be interpreted as an approximation error to the solution of the underlying stochastic differential equation.
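This claim is easy to check by simulation; a Python sketch (seeded, with \(x_0 = 0\) and \(\sigma = 1\) for simplicity, not part of the notebook's Julia code):

```python
import random
import statistics

random.seed(1)
sigma, t_censor = 1.0, 20.0

def draw_final(n_times, wiener):
    """Draw x_{N+1} at the censor time, for a partition of (0, t_censor]
    induced by n_times random "event times", under either prior."""
    ts = [0.0] + sorted(random.uniform(0, t_censor)
                        for _ in range(n_times)) + [t_censor]
    x = 0.0  # take x_0 = 0 for simplicity
    for i in range(1, len(ts)):
        sd = sigma * (ts[i] - ts[i - 1]) ** 0.5 if wiener else sigma
        x += random.gauss(0.0, sd)
    return x

rw5 = statistics.variance(draw_final(5, False) for _ in range(20000))
rw50 = statistics.variance(draw_final(50, False) for _ in range(20000))
wp5 = statistics.variance(draw_final(5, True) for _ in range(20000))
wp50 = statistics.variance(draw_final(50, True) for _ in range(20000))

# Random walk: Var(x_{N+1}) = (N+1)*sigma^2 grows with N ...
assert rw50 > 5 * rw5
# ... while Wiener: Var(x_{N+1}) = t_censor*sigma^2 does not.
assert abs(wp5 - wp50) < 3.0
```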

function pem_survival_model_randomwalk(;
    survived,
    t,
    design_matrix,
    likelihood=true
)
    (;
        n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
    ) = prepare_survival(;t, design_matrix)
    rw_sqrt_scale = @. sqrt(.5*(dt[1:end-1] + dt[2:end]))
    StanBlocks.@stan begin 
        @parameters begin 
            log_hazard_intercept::real
            beta::vector[n_covariates]
            log_hazard_timewise_scale::real(lower=0)
            log_hazard_timewise::vector[n_timepoints]
        end
        log_hazard_personwise = design_matrix*beta
        StanBlocks.@model @views begin 
            log_hazard_intercept ~ normal(0, 1)
            beta ~ cauchy(0, 2)
            log_hazard_timewise_scale ~ normal(0, 1)
            log_hazard_timewise[1] ~ normal(0, 1)
            log_hazard_timewise ~ random_walk(
                StanBlocks.@broadcasted(log_hazard_timewise_scale * rw_sqrt_scale)
            )
            log_lik = Base.broadcast(1:n_persons) do person 
                idxs = 1:end_idxs[person]
                survival_lpdf(
                    survived[person], 
                    StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[idxs]),
                    log_dts[idxs]
                )
            end
            likelihood && (target += sum(log_lik))
        end
        StanBlocks.@generated_quantities begin
            log_lik = collect(log_lik)
            t_pred = map(1:n_persons) do person 
                for timepoint in 1:n_timepoints
                    log_hazard = log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[timepoint]
                    rv = rand(Exponential(exp(-log_hazard)))
                    rv <= dt[timepoint] && return t0[timepoint] + rv
                end
                t1[end]
            end
        end
    end
end
/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // main data matrix (per observed timepoint*record)
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, X]

 // timepoint-specific data (per timepoint, ordered by timepoint id)
 t_obs      = observed time since origin for each timepoint id (end of period)
 t_dur      = duration of each timepoint period (first diff of t_obs)

*/
// Jacqueline Buros Novik <jackinovik@gmail.com>


data {
  // dimensions
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;

  // data matrix
  int<lower=1, upper=N> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars

  // timepoint data
  vector<lower=0>[T] t_obs;
  vector<lower=0>[T] t_dur;
}
transformed data {
  vector[T] log_t_dur;  // log-duration for each timepoint
  int n_trans[S, T];

  log_t_dur = log(t_obs);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }
}
parameters {
  vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
  vector[M] beta;                      // beta for each covariate
  real<lower=0> baseline_sigma;
  real log_baseline_mu;
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;

  log_baseline = log_baseline_raw + log_t_dur;

  for (n in 1:N) {
    log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + x[n,]*beta;
  }
}
model {
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);
  log_baseline_raw[1] ~ normal(0, 1);
  for (i in 2:T) {
      log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
  }
}
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  int y_hat_mat[S, T];     // ppcheck for each S*T combination
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_raw);

  for (n in 1:N) {
      log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
        if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_haz;

              // determine predicted value of y
              // (need to recalc so that carried-forward data use sim tp and not t[n])
              n = n_trans[samp, tp];
              log_haz = log_baseline_mu + log_baseline[tp] + x[n,]*beta;
              if (log_haz < log(pow(2, 30)))
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9;

              // mark this patient as ineligible for future tps
              // note: deliberately make 9s ineligible
              if (pred_y >= 1) {
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

              // save predicted value of y to matrix
              y_hat_mat[samp, tp] = pred_y;
          }
          else if (sample_alive == 0) {
              y_hat_mat[samp, tp] = 9;
          }
      } // end per-timepoint loop

      // if patient still alive at max
      //
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

pem_survival_model_timevarying

  • Discussion
  • Posterior parameter and predictive plots
  • Reimplemented model
  • Original model

To be finished. To keep things short:

  • The original model has the same minor problems as the other models.
  • While the original model implements a random walk prior on the increments of the covariate effects, I’ve kept things a bit simpler and instead just implemented the corresponding Wiener process prior on the values of the covariate effects. IMO, putting a given prior on the increments instead of on the values or vice versa is a modeling decision, and not a “mistake” by any stretch of the imagination. Doing one or the other implies different things, and which choice is “better” is not clear a priori and may depend on the setting.
  • I believe sampling may have failed a bit for the run included in this notebook. I believe I have seen better sampling “runs”, but as this doesn’t have to be perfect, I’ve left it as is.
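To make the scaling implications of that modeling decision concrete, here is a small sketch (Python, assuming unit slab widths and ignoring the intercept and the diffuse prior on the first increment):

```python
def var_value_rw_on_increments(j, sigma2=1.0):
    """Prior variance of the value beta_j when the *increments* follow a
    random walk: beta_j = sum_{k<=j} inc_k with inc_k = sum_{m<=k} d_m and
    d_m iid N(0, sigma2), so beta_j = sum_m (j - m + 1) * d_m."""
    return sigma2 * sum((j - m + 1) ** 2 for m in range(1, j + 1))

def var_value_wiener(t_j, sigma2=1.0):
    """Prior variance of beta_j when the *values* get the Wiener prior."""
    return sigma2 * t_j

# With unit-width slabs (t_j = j), the two choices scale very differently:
assert var_value_rw_on_increments(10) == 385  # sum of squares, O(j^3)
assert var_value_wiener(10.0) == 10.0         # linear in time, O(t_j)
```

So a random walk prior on the increments implies much wider marginal priors on the late-time covariate effects than a Wiener prior on the values; neither is wrong, but they encode different beliefs.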
function pem_survival_model_timevarying(;
    survived,
    t,
    design_matrix,
    likelihood=true
)
    (;
        n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
    ) = prepare_survival(;t, design_matrix)
    rw_sqrt_scale = @. sqrt(.5*(dt[1:end-1] + dt[2:end]))
    StanBlocks.@stan begin 
        @parameters begin 
            log_hazard_intercept::real
            beta_timewise_scale::real(lower=0)
            beta_timewise::matrix[n_covariates, n_timepoints]
            log_hazard_timewise_scale::real(lower=0)
            log_hazard_timewise::vector[n_timepoints]
        end
        log_hazard_personwise = design_matrix*beta_timewise
        StanBlocks.@model @views begin 
            log_hazard_intercept ~ normal(0, 1)
            beta_timewise_scale ~ cauchy(0, 1)
            beta_timewise[:, 1] ~ cauchy(0, 1) 
            beta_timewise' ~ random_walk(
                StanBlocks.@broadcasted(beta_timewise_scale * rw_sqrt_scale)
            )
            log_hazard_timewise_scale ~ normal(0, 1)
            log_hazard_timewise[1] ~ normal(0, 1)
            log_hazard_timewise ~ random_walk(
                StanBlocks.@broadcasted(log_hazard_timewise_scale * rw_sqrt_scale)
            )
            log_lik = Base.broadcast(1:n_persons) do person 
                idxs = 1:end_idxs[person]
                survival_lpdf(
                    survived[person], 
                    StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person, idxs] + log_hazard_timewise[idxs]),
                    log_dts[idxs]
                )
            end
            likelihood && (target += sum(log_lik))
        end
        StanBlocks.@generated_quantities begin
            log_lik = collect(log_lik)
            t_pred = map(1:n_persons) do person 
                for timepoint in 1:n_timepoints
                    log_hazard = log_hazard_intercept + log_hazard_personwise[person, timepoint] + log_hazard_timewise[timepoint]
                    rv = rand(Exponential(exp(-log_hazard)))
                    rv <= dt[timepoint] && return t0[timepoint] + rv
                end
                t1[end]
            end
        end
    end
end
/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // data
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, X]
 obs_t      = observed end time for interval for timepoint for that obs

*/
// Jacqueline Buros Novik <jackinovik@gmail.com>

functions {
  matrix spline(vector x, int N, int H, vector xi, int P) {
    matrix[N, H + P] b_x;         // expanded predictors
    for (n in 1:N) {
        for (p in 1:P) {
            b_x[n,p] <- pow(x[n],p-1);  // x[n]^(p-1)
        }
        for (h in 1:H)
          b_x[n, h + P] <- fmax(0, pow(x[n] - xi[h],P-1));
    }
    return b_x;
  }
}
data {
  // dimensions
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;

  // data matrix
  int<lower=1, upper=N> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars

  // timepoint data
  vector<lower=0>[T] t_obs;
  vector<lower=0>[T] t_dur;
}
transformed data {
  vector[T] log_t_dur;
  int n_trans[S, T];

  log_t_dur = log(t_obs);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }
}
parameters {
  vector[T] log_baseline_raw;    // unstructured baseline hazard for each timepoint t
  real<lower=0> baseline_sigma;
  real log_baseline_mu;

  vector[M] beta; // beta-intercept
  vector<lower=0>[M] beta_time_sigma;
  vector[T-1] raw_beta_time_deltas[M]; // for each coefficient
                                       // change in coefficient value from previous time
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;
  vector[T] beta_time[M];
  vector[T] beta_time_deltas[M];

  // adjust baseline hazard for duration of each period
  log_baseline = log_baseline_raw + log_t_dur;

  // compute timepoint-specific betas
  // offsets from previous time
  for (coef in 1:M) {
      beta_time_deltas[coef][1] = 0;
      for (time in 2:T) {
          beta_time_deltas[coef][time] = raw_beta_time_deltas[coef][time-1];
      }
  }

  // coefficients for each timepoint T
  for (coef in 1:M) {
      beta_time[coef] = beta[coef] + cumulative_sum(beta_time_deltas[coef]);
  }

  // compute log-hazard for each obs
  for (n in 1:N) {
    real log_linpred;
    log_linpred <- 0;
    for (coef in 1:M) {
      // for now, handle each coef separately
      // (to be sure we pull out the "right" beta..)
      log_linpred = log_linpred + x[n, coef] * beta_time[coef][t[n]];
    }
    log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + log_linpred;
  }
}
model {
  // priors on time-varying coefficients
  for (m in 1:M) {
    raw_beta_time_deltas[m][1] ~ normal(0, 100);
    for(i in 2:(T-1)){
        raw_beta_time_deltas[m][i] ~ normal(raw_beta_time_deltas[m][i-1], beta_time_sigma[m]);
    }
  }
  beta_time_sigma ~ cauchy(0, 1);
  beta ~ cauchy(0, 1);

  // priors on baseline hazard
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);
  log_baseline_raw[1] ~ normal(0, 1);
  for (i in 2:T) {
      log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
  }

  // model
  event ~ poisson_log(log_hazard);
}
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  int y_hat_mat[S, T];     // ppcheck for each S*T combination
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_raw);

  // log_likelihood for loo-psis
  for (n in 1:N) {
      log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
        if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_linpred;
              real log_haz;

              // determine predicted value of y
              n = n_trans[samp, tp];

              // (borrow code from above to calc linpred)
              // but use sim tp not t[n]
              log_linpred = 0;
              for (coef in 1:M) {
                  // for now, handle each coef separately
                  // (to be sure we pull out the "right" beta..)
                  log_linpred = log_linpred + x[n, coef] * beta_time[coef][tp];
              }
              log_haz = log_baseline_mu + log_baseline[tp] + log_linpred;

              // now, make posterior prediction
              if (log_haz < log(pow(2, 30)))
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9;

              // mark this patient as ineligible for future tps
              // note: deliberately make 9s ineligible
              if (pred_y >= 1) {
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

              // save predicted value of y to matrix
              y_hat_mat[samp, tp] = pred_y;
          }
          else if (sample_alive == 0) {
              y_hat_mat[samp, tp] = 9;
          }
      } // end per-timepoint loop

      // if patient still alive at max
      //
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

Addendum / Disclaimer

  • I am aware that survivalstan hasn’t been updated in the last 7 years (according to github). I have not implemented the above models to unearth any errors or write a competitor. I believe, but haven’t checked, that the “actual” models used by survivalstan are “more” correct. I was mainly curious whether I could do it, and I wanted to see how well StanBlocks.jl does.
  • I’ve skipped the pem_survival_model_gamma model showcased at https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html because I did not understand why the widths of the timeslabs should affect the shape parameter of the Gamma prior. Only after implementing the time-varying models did I discover the models at https://nbviewer.org/github/hammerlab/survivalstan/blob/master/example-notebooks/Test%20new_gamma_survival_model%20with%20simulated%20data.ipynb. Also, the “Worked examples” page lists a “User-supplied PEM survival model with gammahazard”, though for some reason it does not show up in the sidebar for either of the other examples, compare https://jburos.github.io/survivalstan/examples/Example-using-pem_survival_model.html, https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html, https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_randomwalk%20with%20simulated%20data.html and https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_timevarying%20with%20simulated%20data.html.
Source Code
# Reproducing example models from `survivalstan`

## Survival analysis - what's that?

According to [Wikipedia](https://en.wikipedia.org/wiki/Survival_analysis): 

> Survival analysis is a branch of statistics for analyzing the expected duration of time until one event occurs, such as death in biological organisms and failure in mechanical systems.

We'll consider the setting used for the examples at [https://jburos.github.io/survivalstan/Examples.html](https://jburos.github.io/survivalstan/Examples.html). We will have a model which, for a set of persons, takes 

* a set of covariates (age and gender per person),
* a list of times at which the event either occurs, or until which the event did not occur (one time and event/survival indicator per person),

and, after following standard Bayesian procedures via conditioning on observations, yields a way to predict the survival time of unobserved persons, given the same covariates.

For fixed covariates $x$ and model parameters $\theta$, the models below will give us a way to compute a (piecewise exponential) survival function $S(t) = Pr(T > t)$, i.e. a function which models the probability that the event in question has not occurred until the specified time $t$. Usually, as well as in our setting, the survival function will be the solution to a simple [linear first-order differential equation with variable coefficients](https://en.wikipedia.org/wiki/Linear_differential_equation#First-order_equation_with_variable_coefficients); concretely, we have 

$$
S'(t) = -\lambda(t)S(t)\quad\text{and}\quad S(0) = 1
$$
where the hazard function/rate $\lambda(t)$ is a non-negative function, such that $S(t)$ is monotonically non-increasing and has values in $(0, 1]$. The log of the survival function is then
$$
\log S(t) = -\int_0^t\lambda(\tau) d\tau.
$$

As $S(t)$ models the survival (the non-occurrence of an event), the **log likelihood of the occurrence of an event at a given time $t$** is 
$$
\log p_1(t) = \log -S'(t) = \log \lambda(t) + \log S(t) =  \log \lambda(t) -\int_0^t\lambda(\tau) d\tau
$$
and the **log likelihood of survival up to at least time $t$** is
$$
\log p_0(t) = \log S(t) = -\int_0^t\lambda(\tau) d\tau.
$$

The first term ($p_1(t)$) has to be used for the likelihood contribution of observations where the event occurs (survival up to exactly time $t$), while the second term ($p_0(t)$) has to be used for the likelihood contribution of observations where the event does not occur until the end of the observation time, a.k.a. censored observations. 
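For a constant hazard $\lambda$, the two terms above reduce to the log-density and log-survival function of the Exponential distribution. As a quick illustrative check (a Python sketch with arbitrary toy values; the document's actual code is Julia/Stan), we can verify $p_1(t) = -S'(t)$ numerically:

```python
import math

lam, t, h = 0.7, 3.2, 1e-6  # arbitrary hazard rate, time, finite-difference step

S = lambda t: math.exp(-lam * t)        # survival function for constant hazard
dS = (S(t + h) - S(t - h)) / (2 * h)    # numerical S'(t), which is negative

log_p1 = math.log(lam) - lam * t        # event occurring exactly at t
log_p0 = -lam * t                       # survival until at least t

assert math.isclose(log_p1, math.log(-dS), rel_tol=1e-6)
assert math.isclose(log_p0, math.log(S(t)))
```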

If the hazard function $\lambda(\tau)$ is constant and if we do not care about constant terms (as e.g. during MCMC) we can use the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) to compute the appropriate terms "automatically". For piecewise constant hazard functions, it's possible to chain individual Poisson likelihoods to compute the overall likelihood (modulo a constant term).
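The "automatic" Poisson trick can be made concrete with a small Python sketch (illustrative only; toy values). With an event indicator $d \in \{0, 1\}$ and exposure $t$, the Poisson log-pmf at mean $\lambda t$ differs from the survival log-likelihood above only by $d\log t - \log(d!)$, which is constant in $\lambda$ and hence irrelevant for MCMC:

```python
import math

def poisson_logpmf(d, mu):
    # log pmf of Poisson(mu) at d
    return d * math.log(mu) - mu - math.lgamma(d + 1)

def survival_loglik(d, lam, t):
    # d = 1: event at time t (log p1); d = 0: censored at time t (log p0)
    return d * math.log(lam) - lam * t

t, d = 4.0, 1
diffs = {lam: poisson_logpmf(d, lam * t) - survival_loglik(d, lam, t)
         for lam in (0.1, 0.5, 2.0)}
# the difference d*log(t) - log(d!) does not depend on lam
assert math.isclose(min(diffs.values()), max(diffs.values()))
```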


For **piecewise constant hazard functions** of the form 
$$
\lambda(t) = \begin{cases}
    \lambda_1 & \text{if } t \in [t_0, t_1],\\
    \lambda_2 & \text{if } t \in (t_1, t_2],\\
    \dots
\end{cases}
$$ 
with $0 = t_0 < t_1 < t_2 < \dots$ the survival function can be directly computed as 
$$
\log S(t_j) = -\sum_{i=1}^j (t_i-t_{i-1}) \lambda_i.
$$
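The cumulative-sum formula for $\log S(t_j)$ is straightforward to compute; a minimal Python sketch (hypothetical slab boundaries and hazards, purely illustrative):

```python
import math
from itertools import accumulate

# slab boundaries 0 = t0 < t1 < ... and per-slab hazard values (toy numbers)
t = [0.0, 1.0, 2.5, 4.0]
lam = [0.2, 0.5, 0.1]

widths = [b - a for a, b in zip(t, t[1:])]
# log S(t_j) = -sum_{i<=j} (t_i - t_{i-1}) * lam_i
log_S = [-c for c in accumulate(w * l for w, l in zip(widths, lam))]

assert math.isclose(log_S[-1], -(1.0 * 0.2 + 1.5 * 0.5 + 1.5 * 0.1))
```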


## Why do *this*? Why reimplement things?

Out of curiosity, to figure out whether I could, and to demonstrate that I understand survival analysis. 
Writing down the math is nice and all, but to get correct simulation results, every little detail has to be right. 
At least in principle; in practice, the simulation can still be subtly wrong due to errors which don't crash everything, but only e.g. introduce biases.

## Simulation

<details>
  <summary>Simulated data</summary>

To simulate the data, we generate (for 100 persons) 

* the `age` from a Poisson distribution with mean 55,
* the gender (`male` or not) from a Bernoulli distribution with mean 1/2,
* assume a constant (in time) hazard function, computed from `male` as `log(hazard) = -3 + .5 * male`,
* draw true survival times `true_t` from an Exponential distribution with rate parameter `hazard`,
* cap them at a `censor_time` of 20, i.e. `t = min(true_t, censor_time)`, and 
* set `survived` to `true` if `true_t > censor_time` and `false` otherwise.

::: {.panel-tabset}

### Dataframe

```{julia}
include("index.jl")  
sim = Simulation(100, 20., @formula(log_rate~1+male), [-3,.5]);
sim.df  
```
### Code

Currently, the `formula`/`rate_form` used is hardcoded to match the examples.

```{.julia include="../src/functions.jl" snippet="sim_data_exp_correlated"} 
```

:::
</details>

For all simulations, 

* we model the hazard function $\lambda_i(t)$ of person $i = 1,\dots,100$ to be piecewise constant, with as many pieces as there are unique event times, plus a final one which goes from the largest observed event time to the censoring time,
* every person's hazard function is unique (provided the covariates are unique),
* the personwise ($i$) and timeslabwise ($j$) hazard values will be of the form 
$$
\log\lambda_{i,j} = \log a  + \log\kappa_j + \langle{}X_i,\beta_j\rangle{},
$$
where $\log a$ is a scalar intercept, $\log\kappa_j$ is a time-varying (but person-constant) effect, $X_i$ are the $i$-th person's covariates, and $\beta_j$ are the potentially time-varying covariate effects (in timeslab $j$). For the first two models, $\beta$ will be constant, while it will vary for the last model.
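The hazard parameterization above can be sketched in a few lines of plain Python (toy sizes and randomly drawn values, purely illustrative; the actual implementations are the Julia/Stan models below):

```python
import random

random.seed(0)
N, J, M = 5, 4, 2  # persons, timeslabs, covariates (toy sizes)

log_a = -3.0                                     # scalar intercept
log_kappa = [random.gauss(0, 1) for _ in range(J)]   # time-varying, person-constant
X = [[random.gauss(0, 1) for _ in range(M)] for _ in range(N)]  # covariates
beta = [[random.gauss(0, 1) for _ in range(M)] for _ in range(J)]  # per-slab effects

# log lambda[i][j] = log a + log kappa_j + <X_i, beta_j>
log_lam = [[log_a + log_kappa[j] + sum(X[i][m] * beta[j][m] for m in range(M))
            for j in range(J)] for i in range(N)]

assert len(log_lam) == N and all(len(row) == J for row in log_lam)
```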

### `pem_survival_model`

::: {.panel-tabset}

#### Discussion

The easiest model. The covariate effects are constant ($\beta_1=\beta_2=\dots$) and the time-varying (but person-constant) effect $\log\kappa_j$ has a hierarchical normal prior with mean 0 and unknown scale (with a standard half-normal prior). There seems to be a small mistake in the original model, where at line 42 (AFAICT) `log_t_dur = log(t_obs)` assigns the logarithm of the event *time* to the variable which should contain the logarithm of the timeslab width.
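To illustrate the suspected mix-up (a Python sketch with hypothetical toy values; variable names follow the original Stan code), the slab widths `t_dur` are the successive differences of the sorted unique event times `t_obs`, not the times themselves:

```python
t_obs = [1.0, 2.5, 4.0]  # sorted unique event times (toy values)
# slab widths: differences between consecutive boundaries, starting at 0
t_dur = [b - a for a, b in zip([0.0] + t_obs, t_obs)]

assert t_dur == [1.0, 1.5, 1.5]
# the suspected bug: log(t_obs) is used where log(t_dur) is intended,
# and the two only coincide for the first slab
assert t_dur != t_obs
```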

#### Posterior parameter and predictive plots

```{julia}
plot_summary(sim.lr1...; sim.df) 
```

#### Reimplemented model

```{.julia include="../src/models.jl" snippet="pem_survival_model"} 
```

#### Original model

```{.stan include="survivalstan/pem_survival_model.stan"}
```

:::

### `pem_survival_model_randomwalk`

::: {.panel-tabset}

#### Discussion

Identical to the first model, except that the time-varying (but person-constant) effect $\log\kappa_j$ should have a "random walk" prior. AFAICT, the original model has the same small mistake as the first one (this time at line 43), but **IMO some other (minor) things go "wrong" in constructing the "random walk" prior; or rather, I believe that instead of the random walk prior as implemented in the original code, an approximate Brownian motion / Wiener process prior would have been a better choice:**

*A random walk prior as implemented in the original code will imply different priors for different numbers of persons and also for different realizations of the event times, while an approximate Wiener process prior does not (or rather, much less).* Consider the following:

##### (Gaussian) random walk prior

For random walk parameters $x_1, x_2, \dots$ with scale parameter $\sigma$, the (conditional) prior density is
$$
    p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, \sigma^2) \text{ for } i=1,2,\dots
$$
and with $x_0$ another parameter with appropriate prior.

##### Approximate (Gaussian) Wiener process prior

Following [Wikipedia](https://en.wikipedia.org/wiki/Wiener_process):

> The Wiener process $W_t$ is characterised by the following properties: 
> [...] W has Gaussian increments: [...] $W_{t+u} - W_t \sim \mathcal{N}(0,u)$.

I.e., for timepoints $0 = t_0 < t_1 < t_2 < \dots$ as above, the (conditional) prior density of the (shifted) Wiener process
values $x_1, x_2, \dots$ with scale parameter $\sigma$ is
$$
    p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, (t_i-t_{i-1})\sigma^2) \text{ for } i=1,2,\dots
$$
and with $x_0$ as before.

##### Dependence on the observed event times

The difference between the two priors will become most easily apparent by looking at the implied prior on the (log) hazard at (or right before) the censor time $t_\text{censor} = t_{N+1}$, for varying numbers of unique observed event times $N$. For the **random walk prior**, we'll have 
$$
x_j \sim \mathcal{N}(x_0, j\sigma^2) \text{ for } j = 1,\dots,N+1,
$$
while for the **Wiener process prior**, we'll have
$$
x_j \sim \mathcal{N}(x_0, t_j\sigma^2) \text{ for } j = 1,\dots,N+1.
$$
In particular, for $j=N+1$ (i.e. at censor time), we get a constant prior distribution for the Wiener process prior, but for the random walk prior we get a prior distribution that depends on the number of unique observed event times $N$. Similarly, even for fixed $N$, there is a (potentially strong) dependence of the implied prior for "interior" time slabs on the realization of the event times for the random walk prior, while there's "no" such dependence for the Wiener process prior. *Caveat: There* will *actually be a dependence of the implied prior on the event time realizations also for the Wiener process, but this is only due to the piecewise-constant "assumption" and can be interpreted as an approximation error to the solution of the underlying stochastic differential equation.*
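The $N$-dependence at censor time can be made concrete with a tiny sketch (illustrative Python; `sigma2` and `t_censor` are arbitrary values matching the simulation's censor time of 20):

```python
# implied prior variance of x_{N+1} (at censor time), conditional on x_0
sigma2, t_censor = 1.0, 20.0

def rw_var(N):
    # random walk prior: variance grows with the number of unique event times
    return (N + 1) * sigma2

def wiener_var(N):
    # Wiener process prior: variance depends only on the censor time (N unused)
    return t_censor * sigma2

assert rw_var(5) != rw_var(50)          # depends on N
assert wiener_var(5) == wiener_var(50)  # does not
```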




#### Posterior parameter and predictive plots

```{julia}
plot_summary(sim.lr2...; sim.df)
```

#### Reimplemented model

```{.julia include="../src/models.jl" snippet="pem_survival_model_randomwalk"} 
```

#### Original model

```{.stan include="survivalstan/pem_survival_model_randomwalk.stan"}
```

:::

### `pem_survival_model_timevarying`

::: {.panel-tabset}

#### Discussion

To be finished. To keep things short: 

* The original model has the same minor problems as the other models.
* While the original model implements a random walk prior on the *increments* of the covariate effects, I've kept things a bit simpler and instead just implemented the corresponding Wiener process prior on the *values* of the covariate effects. IMO, putting a given prior on the increments instead of on the values or vice versa is a *modeling decision*, and not a "mistake" by any stretch of the imagination. Doing one or the other implies different things, and which choice is "better" is not clear a priori and may depend on the setting.
* I believe sampling may have failed a bit for the run included in this notebook. I believe I have seen better sampling "runs", but as this doesn't have to be perfect, I've left it as is.

#### Posterior parameter and predictive plots

```{julia}
plot_summary(sim.lr3...; sim.df)
```

#### Reimplemented model

```{.julia include="../src/models.jl" snippet="pem_survival_model_timevarying"} 
```

#### Original model

```{.stan include="survivalstan/pem_survival_model_timevarying.stan"}
```

:::

## Addendum / Disclaimer

* I am aware that survivalstan hasn't been updated in the last 7 years (according to [github](https://github.com/hammerlab/survivalstan)). I have not implemented the above models to unearth any errors or write a competitor. I believe, but haven't checked, that the "actual" models used by survivalstan are "more" correct. I was mainly curious whether I could do it, and I wanted to see how well [StanBlocks.jl](https://github.com/nsiccha/StanBlocks.jl) does.
* I've skipped the `pem_survival_model_gamma` model showcased at [https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html](https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html) because I did not understand why the widths of the timeslabs should affect the **shape** parameter of the Gamma prior. Only after implementing the time-varying models did I discover the models at [https://nbviewer.org/github/hammerlab/survivalstan/blob/master/example-notebooks/Test%20new_gamma_survival_model%20with%20simulated%20data.ipynb](https://nbviewer.org/github/hammerlab/survivalstan/blob/master/example-notebooks/Test%20new_gamma_survival_model%20with%20simulated%20data.ipynb). Also, the ["Worked examples" page](https://jburos.github.io/survivalstan/Examples.html) lists a ["User-supplied PEM survival model with gammahazard"](https://jburos.github.io/survivalstan/examples/Test%20new_gamma_survival_model%20with%20simulated%20data.html), though for some reason it does not show up in the sidebar for either of the other examples, compare [https://jburos.github.io/survivalstan/examples/Example-using-pem_survival_model.html](https://jburos.github.io/survivalstan/examples/Example-using-pem_survival_model.html), [https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html](https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html), [https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_randomwalk%20with%20simulated%20data.html](https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_randomwalk%20with%20simulated%20data.html) and 
[https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_timevarying%20with%20simulated%20data.html](https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_timevarying%20with%20simulated%20data.html).